{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "8d3a4214",
      "metadata": {
        "id": "8d3a4214"
      },
      "source": [
        "# Finetune on ScienceQA\n",
        "Let's use LLM Engine to fine-tune Llama-2 on ScienceQA!"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "XK6VpTnOL4OV",
      "metadata": {
        "id": "XK6VpTnOL4OV"
      },
      "source": [
        "# Packages Required\n",
        "For this demo, we'll be using the `scale-llm-engine` package and the `datasets` library from Hugging Face.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "S5u6DdInMEQ7",
      "metadata": {
        "id": "S5u6DdInMEQ7"
      },
      "outputs": [],
      "source": [
        "# Use %pip (not !pip) so packages are installed into the running kernel's environment.\n",
        "# smart_open is imported later in this notebook (used to write the CSVs to S3),\n",
        "# so it is installed here as well; the [s3] extra pulls in the S3 dependencies.\n",
        "%pip install scale-llm-engine datasets \"smart_open[s3]\""
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a3dc2a56",
      "metadata": {
        "id": "a3dc2a56"
      },
      "source": [
        "# Data Preparation\n",
        "Let's load the dataset with Hugging Face `datasets` and view its features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e06ac39e",
      "metadata": {
        "id": "e06ac39e"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "# `smart_open.smart_open` is deprecated in favour of `smart_open.open`; alias the new\n",
        "# function to the old name so the rest of the notebook can keep calling\n",
        "# `smart_open(url, mode)` unchanged.\n",
        "from smart_open import open as smart_open\n",
        "\n",
        "# Download ScienceQA from the Hugging Face Hub; the notebook uses its\n",
        "# 'train' and 'validation' splits.\n",
        "dataset = load_dataset('derek-thomas/ScienceQA')\n",
        "\n",
        "# Display the column schema of the training split.\n",
        "dataset['train'].features"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "1cbe8a58",
      "metadata": {
        "id": "1cbe8a58"
      },
      "source": [
        "Now, let's convert the dataset into the format LLM Engine accepts: a CSV file with `prompt` and `response` columns."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0b0eb8ad",
      "metadata": {
        "id": "0b0eb8ad"
      },
      "outputs": [],
      "source": [
        "# pandas is needed by convert_dataset below; import it here so this cell\n",
        "# also works on a fresh kernel (Restart & Run All).\n",
        "import pandas as pd\n",
        "\n",
        "choice_prefixes = [chr(ord('A') + i) for i in range(26)] # A-Z\n",
        "\n",
        "def format_options(options, choice_prefixes):\n",
        "    \"\"\"Render the answer choices on one line, e.g. '(A) first (B) second'.\"\"\"\n",
        "    return ' '.join([f'({c}) {o}' for c, o in zip(choice_prefixes, options)])\n",
        "\n",
        "def format_prompt(r, choice_prefixes):\n",
        "    \"\"\"Build the prompt for one row from its 'hint', 'question' and 'choices' fields.\"\"\"\n",
        "    options = format_options(r['choices'], choice_prefixes)\n",
        "    return f'''Context: {r[\"hint\"]}\\nQuestion: {r[\"question\"]}\\nOptions:{options}\\nAnswer:'''\n",
        "\n",
        "def format_label(r, choice_prefixes):\n",
        "    \"\"\"Map the row's integer 'answer' index to its letter (0 -> 'A', 1 -> 'B', ...).\"\"\"\n",
        "    return choice_prefixes[r['answer']]\n",
        "\n",
        "def convert_dataset(ds):\n",
        "    \"\"\"Convert a dataset split into a DataFrame with 'prompt' and 'response' columns.\n",
        "\n",
        "    Rows whose 'hint' is empty are skipped.\n",
        "    \"\"\"\n",
        "    prompts, labels = [], []\n",
        "    # Single pass: the filter is applied once, so prompts and labels are always\n",
        "    # built from exactly the same rows and the split is only iterated once.\n",
        "    for row in ds:\n",
        "        if row['hint'] == '':\n",
        "            continue\n",
        "        prompts.append(format_prompt(row, choice_prefixes))\n",
        "        labels.append(format_label(row, choice_prefixes))\n",
        "    return pd.DataFrame.from_dict({'prompt': prompts, 'response': labels})\n",
        "\n",
        "save_to_s3 = False\n",
        "# Converted once here; used both for the optional upload and for the preview below.\n",
        "df_train = convert_dataset(dataset['train'])\n",
        "if save_to_s3:\n",
        "    # Placeholder URLs: replace with your own bucket/key before enabling save_to_s3.\n",
        "    train_url = 's3://...'\n",
        "    val_url = 's3://...'\n",
        "    with smart_open(train_url, 'wb') as f:\n",
        "        df_train.to_csv(f)\n",
        "\n",
        "    df_val = convert_dataset(dataset['validation'])\n",
        "    with smart_open(val_url, 'wb') as f:\n",
        "        df_val.to_csv(f)\n",
        "else:\n",
        "    # Gists of the already processed datasets\n",
        "    train_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_train.csv'\n",
        "    val_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_val.csv'\n",
        "\n",
        "df_train"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "e2fc8d76",
      "metadata": {
        "id": "e2fc8d76"
      },
      "source": [
        "# Fine-tune\n",
        "Now, we can fine-tune the model using LLM Engine."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4905d447",
      "metadata": {
        "id": "4905d447"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from getpass import getpass\n",
        "\n",
        "# Never hardcode the API key in the notebook. Reuse SCALE_API_KEY when it is\n",
        "# already set in the environment; otherwise prompt for it without echoing.\n",
        "if not os.environ.get('SCALE_API_KEY'):\n",
        "    os.environ['SCALE_API_KEY'] = getpass('Scale API key: ')\n",
        "\n",
        "from llmengine import FineTune\n",
        "\n",
        "# Launch the fine-tuning job on the prepared train/validation CSVs.\n",
        "response = FineTune.create(\n",
        "    model=\"llama-2-7b\",\n",
        "    training_file=train_url,\n",
        "    validation_file=val_url,\n",
        "    hyperparameters={\n",
        "        'lr':2e-4,\n",
        "    },\n",
        "    suffix='science-qa-llama'\n",
        ")\n",
        "# Keep the job id so later cells can poll the job's status.\n",
        "run_id = response.id"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "55074457",
      "metadata": {
        "id": "55074457"
      },
      "source": [
        "We can poll the job status once a minute until the fine-tuning job completes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "840938dd",
      "metadata": {
        "id": "840938dd"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "\n",
        "# Statuses after which the job can no longer reach SUCCESS; without this check\n",
        "# a failed or cancelled job would keep the loop polling forever.\n",
        "# NOTE(review): values taken from LLM Engine's job status enum - confirm against the client version in use.\n",
        "failed_statuses = ('FAILURE', 'CANCELLED', 'TIMEOUT')\n",
        "\n",
        "while True:\n",
        "    job = FineTune.get(run_id)\n",
        "    job_status = job.status\n",
        "    print(job_status)\n",
        "    if job_status == 'SUCCESS':\n",
        "        break\n",
        "    if job_status in failed_statuses:\n",
        "        raise RuntimeError(f'Fine-tune job {run_id} ended with status {job_status}')\n",
        "    # Poll once a minute.\n",
        "    time.sleep(60)\n",
        "\n",
        "# The last response fetched is the successful one; read the model name from it.\n",
        "fine_tuned_model = job.fine_tuned_model"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "31278c6d",
      "metadata": {
        "id": "31278c6d"
      },
      "source": [
        "# Inference and Evaluation\n",
        "Let's evaluate the new fine-tuned model by running inference against it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3b9d7643",
      "metadata": {
        "id": "3b9d7643"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "from llmengine import Completion\n",
        "\n",
        "def get_output(prompt: str, num_retry: int = 5):\n",
        "    \"\"\"Ask the fine-tuned model for a one-token completion of `prompt`.\n",
        "\n",
        "    Makes up to `num_retry` attempts; each failure is printed and the call is retried.\n",
        "    Returns the stripped completion text, or an empty string if every attempt fails.\n",
        "    \"\"\"\n",
        "    for _attempt in range(num_retry):\n",
        "        try:\n",
        "            completion = Completion.create(\n",
        "                model=fine_tuned_model,\n",
        "                prompt=prompt,\n",
        "                max_new_tokens=1,\n",
        "                temperature=0.01\n",
        "            )\n",
        "            answer = completion.output.text\n",
        "            return answer.strip()\n",
        "        except Exception as err:\n",
        "            print(err)\n",
        "    return \"\"\n",
        "\n",
        "# Load the validation CSV and predict an answer letter for every prompt.\n",
        "test = pd.read_csv(val_url)\n",
        "test[\"prediction\"] = test[\"prompt\"].map(get_output)\n",
        "\n",
        "# Accuracy = share of rows whose predicted letter equals the reference letter.\n",
        "is_correct = test['response'] == test['prediction']\n",
        "print(f\"Accuracy: {is_correct.mean() * 100:.2f}%\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "9f2f3f43",
      "metadata": {
        "id": "9f2f3f43"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Environment (conda_pytorch_p38)",
      "language": "python",
      "name": "conda_pytorch_p38"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
